{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import h2o4gpu\n",
    "import h2o4gpu.util.import_data as io\n",
    "import h2o4gpu.util.metrics as metrics\n",
    "from tabulate import tabulate\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading Data with Pandas\n",
      "(999, 9733)\n",
      "Original m=999 n=9732\n",
      "Size of Train rows=800 & valid rows=199\n",
      "Size of Train cols=9732 valid cols=9732\n",
      "Size of Train cols=9733 & valid cols=9733 after adding intercept column\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Import Data for H2O GPU Edition\n",
    "\n",
    "This function will read in data and prepare it for H2O4GPU's GLM solver\n",
    "\n",
    "Parameters\n",
    "----------\n",
    "data_path : str\n",
    "             A path to a dataset (The dataset needs to be all numeric)\n",
    "use_pandas : bool\n",
    "              Indicate if Pandas should be used to parse\n",
    "intercept : bool\n",
    "              Indicate if intercept term is needed\n",
    "valid_fraction : float\n",
    "                  Percentage of dataset reserved for a validation set\n",
    "classification : bool\n",
    "                  Classification problem?\n",
    "Returns\n",
    "-------\n",
    "If valid_fraction > 0 it will return the following:\n",
    "    train_x: numpy array of train input variables\n",
    "    train_y: numpy array of y variable\n",
    "    valid_x: numpy array of valid input variables\n",
    "    valid_y: numpy array of valid y variable\n",
    "    family : string that would either be \"logistic\" if classification is set to True, otherwise \"elasticnet\"\n",
    "If valid_fraction == 0 it will return the following:\n",
    "    train_x: numpy array of train input variables\n",
    "    train_y: numpy array of y variable\n",
    "    family : string that would either be \"logistic\" if classification is set to True, otherwise \"elasticnet\"\n",
    "\"\"\"\n",
    "import os\n",
    "if not os.path.exists(\"ipums_1k.csv\"):\n",
    "    !wget https://s3.amazonaws.com/h2o-public-test-data/h2o4gpu/open_data/ipums_1k.csv\n",
    "train_x,train_y,valid_x,valid_y,family=io.import_data(data_path=\"ipums_1k.csv\", \n",
    "                                                        use_pandas=True, \n",
    "                                                        intercept=True,\n",
    "                                                        valid_fraction=0.2,\n",
    "                                                        classification=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Setting up solver\n",
      "Running h2o4gpu Ridge Regression\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Set up instance of H2O4GPU's Ridge solver with default parameters\n",
    "\n",
    "Need to pass in `family` to indicate problem type to solve\n",
    "\"\"\"\n",
    "print(\"Setting up solver\")\n",
    "model = h2o4gpu.Ridge()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Solving\n",
      "CPU times: user 2.84 s, sys: 620 ms, total: 3.46 s\n",
      "Wall time: 3.42 s\n",
      "Done Solving\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Fit Ridge Solver\n",
    "\"\"\"\n",
    "print(\"Solving\")\n",
    "%time model.fit(train_x, train_y)\n",
    "print(\"Done Solving\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predictions per alpha\n",
      "[[  1.44140804e+00   3.95907745e+02   5.85950859e+04   1.35538736e-02\n",
      "    2.24011857e-03  -9.94564220e-03   2.08838610e-03   1.52874296e-03\n",
      "    8.47863592e-03   8.15556982e+03  -1.24143763e-03   9.60603240e-04\n",
      "    6.25534402e-03   1.57668184e+04  -1.02974465e-02   1.85184082e+04\n",
      "    4.43184264e-02  -2.37068860e-04   2.32363149e-04   2.67238274e+01\n",
      "   -7.34141785e+02  -1.36826131e-02  -1.16951093e-02   1.73626280e+00\n",
      "   -2.37373356e-03   1.56881176e-02   4.45240736e+00   1.97632122e+01\n",
      "    7.47358240e-03   3.67370062e-03  -2.05536117e-03   6.71235046e+02\n",
      "    1.45517930e-03   1.65146042e+02   1.65146515e+02   1.84123810e+02\n",
      "   -3.70902475e-03  -5.97808510e-03  -6.31310628e-04  -1.80768259e-02\n",
      "   -5.15626045e-03  -5.77098923e-03   2.64548877e+03   2.64548047e+03\n",
      "    4.90870886e-03   4.42881574e-04   1.39958085e-03   3.42940601e+03\n",
      "   -7.33041016e+02   1.61363633e+04  -3.41391144e-03  -4.12627310e-03\n",
      "    1.37159340e-02   1.27209687e+00  -3.47321550e-03   1.39156298e-03\n",
      "   -1.08198857e+00  -1.08054793e+00   7.36660231e-03   1.27198594e-03\n",
      "    1.15174512e+04  -1.68164226e-03   3.14791775e+00   3.15148783e+00\n",
      "    1.63198542e-03  -1.20079180e-03   4.72718201e+02   6.03787089e-03\n",
      "    1.78480102e-03  -1.71055626e-02  -5.59950294e-03   9.04828453e+00\n",
      "   -1.35061739e-03  -1.61241123e-03   7.59866238e-02  -8.00079480e-03\n",
      "   -5.10328030e-03   2.87667220e-03   3.66678906e+04  -1.57263945e-03\n",
      "    4.73849848e-03   4.13413625e-03   5.17100934e-03   7.68598216e-03\n",
      "   -1.66405342e-03  -7.66057812e-04   3.99940381e+03   2.44887010e-03\n",
      "    2.10839082e-02   4.27258899e-03  -7.33038574e+02  -5.02515165e-03\n",
      "   -9.10319271e-04   2.67171345e+01   6.74435287e-05  -5.15503297e-03\n",
      "    1.36743963e-03  -3.52691929e-03  -3.97895835e-03   3.45039111e+03\n",
      "    1.81764066e+00   2.65978253e-03   2.65543000e-03  -1.84787735e-02\n",
      "    7.12960645e+03  -1.05922678e-02   2.78913379e+03  -3.97800095e-03\n",
      "   -2.54600658e-04  -1.49720686e-03  -2.34268047e-03   4.02539223e-03\n",
      "    4.23551941e+01  -6.96316082e-03  -5.86435010e+03  -2.16082903e-03\n",
      "    3.00719077e-03   3.07656787e+03   3.46287689e-03   4.00112158e+03\n",
      "    5.38996315e+00   1.83340703e+04   1.89744492e+01  -2.66905699e-05\n",
      "    1.80450213e+00   6.71982765e-03   6.55304489e+01  -1.69353164e-03\n",
      "   -2.13628798e-03   3.96103933e-02   2.07309931e-04  -2.17599329e-03\n",
      "    2.05253577e-03   1.54996109e+01  -1.28937429e-02  -1.80694449e-04\n",
      "    4.72712189e+02   2.82231253e-03   7.85103161e-03  -9.02463868e-03\n",
      "   -3.95222148e-03   1.13776670e+04   1.13776582e+04  -2.77159060e-03\n",
      "   -3.36039439e-03   3.53539572e-03  -3.01774871e-03   2.06896523e-03\n",
      "    2.46589453e+04   3.42941382e+03   1.01385809e+03  -7.50974054e-04\n",
      "    1.70926237e+00   1.74063313e+00   1.74684048e+00   5.92225464e-03\n",
      "   -2.22545350e-03  -6.94452319e-05  -4.07080054e-02   2.93834694e-03\n",
      "    5.52046904e-03   1.30073643e+04  -3.94681701e-03  -4.50224336e-03\n",
      "    1.01062869e-04   4.92509812e-01   4.86607790e-01  -1.50919310e-03\n",
      "    2.61710677e-02   1.88733354e-01   5.45421755e-03  -1.61218073e-03\n",
      "    8.46655294e-03  -2.21211556e-03  -3.13416123e-03  -8.23132601e-03\n",
      "   -1.34511935e-02   1.15063855e+03   5.53571098e-02   7.24710058e-04\n",
      "   -4.42641973e-03  -1.02149602e-02   3.29264365e-02  -1.89240407e-02\n",
      "    8.53432622e-03   1.58685297e-02   2.08236706e-02  -6.68194331e-03\n",
      "    7.32387183e-04   3.95910522e+02   3.94486636e-03  -3.98397027e-03\n",
      "    1.35999569e-03   2.68641597e-05   1.28144713e-03   4.70679248e+03\n",
      "    1.03962617e-02   7.29913171e-03   7.12201223e-02]]\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Make predictions on validation set\n",
    "\"\"\"\n",
    "print(\"Predictions per alpha\")\n",
    "preds = model.predict(valid_x)\n",
    "print(preds)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
